{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02121.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02121.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C9vYXrmyqw-s",
        "colab_type": "text"
      },
      "source": [
        "### 10.3 word2vec 的实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sKiJAasDsIRh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import collections \n",
        "import math \n",
        "import random \n",
        "import time \n",
        "import os \n",
        "import numpy as np \n",
        "import torch \n",
        "from torch import nn \n",
        "import torch.utils.data as Data \n",
        "import d2l \n",
        "import zipfile"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o-dt8C3CwsVt",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.1 处理数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8gWbm-6gw3Ol",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "ebd04609-d62a-4398-bdc5-e94f884c0a57"
      },
      "source": [
        "# -p makes this cell idempotent: no error when ../data already exists on a re-run.\n",
        "!mkdir -p ../data"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "mkdir: cannot create directory ‘../data’: File exists\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5zaoXj0qwxqM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Skip the clone when the checkout already exists so re-running the cell does not fail.\n",
        "# A shallow clone is used because this notebook only copies data/ptb.zip out of the repo.\n",
        "![ -d d2l-zh ] || git clone --depth 1 https://github.com/d2l-ai/d2l-zh.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FdPCcfpPw1Q9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!cp d2l-zh/data/ptb.zip ../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YkJAS2S1xfWv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Extract every member of the PTB archive into the data directory.\n",
        "ptb_zip_path, extract_dir = '../data/ptb.zip', '../data/'\n",
        "with zipfile.ZipFile(ptb_zip_path, 'r') as archive:\n",
        "    archive.extractall(extract_dir)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iy4-Lz-VyA1z",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9d07cff3-a3cf-4dab-c846-419d0d75ce30"
      },
      "source": [
        "# Read the PTB training split; every line is treated as one sentence.\n",
        "with open('../data/ptb/ptb.train.txt', 'r') as f:\n",
        "    lines = f.readlines()\n",
        "\n",
        "# Whitespace-tokenise each line once the file is closed.\n",
        "raw_dataset = list(map(str.split, lines))\n",
        "\n",
        "'# sentences: %d' % len(raw_dataset)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'# sentences: 42068'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qqoaplDWzGo-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "23efb993-f1b7-4861-8784-1d07ff52e8a3"
      },
      "source": [
        "for st in raw_dataset[:3]:\n",
        "    print('# tokens:', len(st), st[:5])"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']\n",
            "# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']\n",
            "# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XY6kl39ozRKZ",
        "colab_type": "text"
      },
      "source": [
        "#### 1 建立词语索引\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iSzvl_rxzZZO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Count every token occurrence, then keep only tokens that appear at least 5 times.\n",
        "token_counts = collections.Counter(tk for st in raw_dataset for tk in st)\n",
        "counter = {token: freq for token, freq in token_counts.items() if freq >= 5}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fgk8u-yTzxGs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "3e73776f-fb28-40ae-99ea-b1dc537ab972"
      },
      "source": [
        "# Vocabulary: a token's position in idx_to_token is its integer id.\n",
        "idx_to_token = list(counter)\n",
        "token_to_idx = dict(zip(idx_to_token, range(len(idx_to_token))))\n",
        "# Convert each sentence to ids, dropping tokens that are not in the vocabulary.\n",
        "dataset = [[token_to_idx[token] for token in sentence if token in token_to_idx]\n",
        "           for sentence in raw_dataset]\n",
        "num_tokens = sum(map(len, dataset))\n",
        "'# tokens: %d' % num_tokens"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'# tokens: 887100'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jl6D5H841V7M",
        "colab_type": "text"
      },
      "source": [
        "#### 2 二次采样"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mxWDiHJ11a-1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b764454c-fb0d-4337-ca42-4ebc38d3f550"
      },
      "source": [
        "def discard(idx):\n",
        "    \"\"\"Randomly decide whether one occurrence of token id `idx` is dropped.\n",
        "\n",
        "    Reads the module-level `counter`, `idx_to_token` and `num_tokens`.\n",
        "    Returns True when a uniform draw from [0, 1] falls below\n",
        "    1 - sqrt(1e-4 / count * num_tokens); when that threshold is <= 0 the\n",
        "    token is never dropped.\n",
        "    \"\"\"\n",
        "    draw = random.uniform(0, 1)\n",
        "    ratio = 1e-4 / counter[idx_to_token[idx]] * num_tokens\n",
        "    return draw < 1 - math.sqrt(ratio)\n",
        "\n",
        "subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]\n",
        "'# tokens: %d' % sum([len(st) for st in subsampled_dataset])"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'# tokens: 375563'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K96s46ei2OA1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "27af8e2e-6213-41fc-fad6-ee2908167269"
      },
      "source": [
        "def compare_counts(token):\n",
        "    \"\"\"Report how often `token` occurs before and after subsampling.\n",
        "\n",
        "    Counts the token's id in the module-level `dataset` and\n",
        "    `subsampled_dataset` and returns the formatted summary string.\n",
        "    \"\"\"\n",
        "    def occurrences(corpus):\n",
        "        return sum(sentence.count(token_to_idx[token]) for sentence in corpus)\n",
        "\n",
        "    before = occurrences(dataset)\n",
        "    after = occurrences(subsampled_dataset)\n",
        "    return '# %s: before=%d, after=%d' % (token, before, after)\n",
        "\n",
        "compare_counts('the')"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'# the: before=50770, after=2081'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w24Sc8Sw3G-I",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "596647cc-1c69-4156-81b2-4c84b5e9cdb5"
      },
      "source": [
        "compare_counts('join')"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'# join: before=45, after=45'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TIIKCP0u3LeV",
        "colab_type": "text"
      },
      "source": [
        "#### 3 提取中心词和背景词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_FnxE_ef3TOj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_centers_and_contexts(dataset, max_window_size):\n",
        "    \"\"\"Pair every token with the tokens in a randomly sized window around it.\n",
        "\n",
        "    Sentences shorter than 2 tokens are skipped. For every position of the\n",
        "    remaining sentences a window size is drawn with\n",
        "    random.randint(1, max_window_size) and clipped at the sentence boundaries.\n",
        "\n",
        "    Returns:\n",
        "        (centers, contexts): a flat list of center tokens and, aligned with it,\n",
        "        one list of context tokens per center.\n",
        "    \"\"\"\n",
        "    centers, contexts = [], []\n",
        "    for sentence in dataset:\n",
        "        length = len(sentence)\n",
        "        if length >= 2:\n",
        "            centers.extend(sentence)\n",
        "            for pos in range(length):\n",
        "                radius = random.randint(1, max_window_size)\n",
        "                lo = max(0, pos - radius)\n",
        "                hi = min(length, pos + 1 + radius)\n",
        "                # Everything inside the window except the center itself.\n",
        "                contexts.append([sentence[j] for j in range(lo, hi) if j != pos])\n",
        "    return centers, contexts"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6WCOye3o4ZMy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 202
        },
        "outputId": "9bf8c3f2-1f58-4936-8123-9f8dc63a7602"
      },
      "source": [
        "tiny_dataset = [list(range(7)), list(range(7, 10))]\n",
        "print('dataset', tiny_dataset)\n",
        "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n",
        "    print('center', center, 'has contexts', context)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n",
            "center 0 has contexts [1]\n",
            "center 1 has contexts [0, 2, 3]\n",
            "center 2 has contexts [1, 3]\n",
            "center 3 has contexts [1, 2, 4, 5]\n",
            "center 4 has contexts [3, 5]\n",
            "center 5 has contexts [4, 6]\n",
            "center 6 has contexts [4, 5]\n",
            "center 7 has contexts [8, 9]\n",
            "center 8 has contexts [7, 9]\n",
            "center 9 has contexts [7, 8]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JE0VuJN_4zDR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f58_Kwv75GXt",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.2 负采样"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cGqmLbPK5MQ-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_negatives(all_contexts, sampling_weights, K):\n",
        "    \"\"\"Draw K noise words per context word for negative sampling.\n",
        "\n",
        "    Args:\n",
        "        all_contexts: list of context-index lists, one per center word.\n",
        "        sampling_weights: relative sampling weight of every index\n",
        "            (position i is the weight of index i).\n",
        "        K: number of negatives drawn for each context word.\n",
        "\n",
        "    Returns:\n",
        "        A list aligned with `all_contexts`; entry i holds\n",
        "        len(all_contexts[i]) * K indices, none of which occurs in\n",
        "        all_contexts[i].\n",
        "\n",
        "    Note:\n",
        "        The sampling loop only stops once enough non-context indices were\n",
        "        drawn, so it does not terminate if a context list covers every index\n",
        "        that can be sampled.\n",
        "    \"\"\"\n",
        "    all_negatives, neg_candidates, i = [], [], 0\n",
        "    population = list(range(len(sampling_weights)))\n",
        "    for contexts in all_contexts:\n",
        "        # Build the exclusion set once per center word, not once per candidate.\n",
        "        excluded = set(contexts)\n",
        "        target = len(contexts) * K\n",
        "        negatives = []\n",
        "        while len(negatives) < target:\n",
        "            if i == len(neg_candidates):\n",
        "                # Candidates are drawn in bulk (1e5 at a time) and consumed one by one.\n",
        "                i, neg_candidates = 0, random.choices(population, sampling_weights, k=int(1e5))\n",
        "            neg, i = neg_candidates[i], i + 1\n",
        "            if neg not in excluded:\n",
        "                negatives.append(neg)\n",
        "        all_negatives.append(negatives)\n",
        "    return all_negatives\n",
        "\n",
        "sampling_weights = [counter[w]**0.75 for w in idx_to_token]\n",
        "all_negatives = get_negatives(all_contexts, sampling_weights, 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wjoxe5_I6vbp",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.3 读取数据"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jquXdGcd7B_T",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MyDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"Dataset of aligned (center, contexts, negatives) training examples.\"\"\"\n",
        "\n",
        "    def __init__(self, centers, contexts, negatives):\n",
        "        # The three sequences are indexed in lockstep, so they must be equally long.\n",
        "        n_centers, n_contexts, n_negatives = map(len, (centers, contexts, negatives))\n",
        "        assert n_centers == n_contexts == n_negatives\n",
        "        self.centers, self.contexts, self.negatives = centers, contexts, negatives\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Number of examples, i.e. the number of center words.\"\"\"\n",
        "        return len(self.centers)\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        \"\"\"Return the (center, contexts, negatives) triple stored at `index`.\"\"\"\n",
        "        sample = self.centers[index], self.contexts[index], self.negatives[index]\n",
        "        return sample"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DQq2Adm27-1P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def batchify(data):\n",
        "    \"\"\"Collate (center, context, negative) examples into padded tensors.\n",
        "\n",
        "    Each example's context and negative lists are concatenated and\n",
        "    right-padded with 0 to the longest combined length in the batch.\n",
        "\n",
        "    Returns a 4-tuple of tensors:\n",
        "        centers            -- one center per row, reshaped with view(-1, 1)\n",
        "        contexts_negatives -- padded context + negative indices\n",
        "        masks              -- 1 on real entries, 0 on padding\n",
        "        labels             -- 1 on context entries, 0 on negatives and padding\n",
        "    \"\"\"\n",
        "    max_len = max(len(context) + len(negative) for _, context, negative in data)\n",
        "    centers, contexts_negatives, masks, labels = [], [], [], []\n",
        "    for center, context, negative in data:\n",
        "        n_context = len(context)\n",
        "        n_pad = max_len - n_context - len(negative)\n",
        "        centers.append(center)\n",
        "        contexts_negatives.append(context + negative + [0] * n_pad)\n",
        "        masks.append([1] * (max_len - n_pad) + [0] * n_pad)\n",
        "        labels.append([1] * n_context + [0] * (max_len - n_context))\n",
        "    center_column = torch.tensor(centers).view(-1, 1)\n",
        "    return (center_column, torch.tensor(contexts_negatives),\n",
        "            torch.tensor(masks), torch.tensor(labels))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XsJw7AeJ9S9m",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "7221de9c-4378-4b9d-b964-94d54cdc2a64"
      },
      "source": [
        "batch_size = 512\n",
        "num_workers = 4\n",
        "# Use a separate name here: `dataset` still holds the indexed sentences that\n",
        "# `compare_counts` and the subsampling cell read, and rebinding it would break\n",
        "# re-running those cells.\n",
        "train_dataset = MyDataset(all_centers, all_contexts, all_negatives)\n",
        "data_iter = Data.DataLoader(train_dataset, batch_size, shuffle=True,\n",
        "                            collate_fn=batchify, num_workers=num_workers)\n",
        "# Look at the first batch only, to show the tensor shapes.\n",
        "for batch in data_iter:\n",
        "    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):\n",
        "        print(name, 'shape:', data.shape)\n",
        "    break"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "centers shape: torch.Size([512, 1])\n",
            "contexts_negatives shape: torch.Size([512, 60])\n",
            "masks shape: torch.Size([512, 60])\n",
            "labels shape: torch.Size([512, 60])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gsILHnha-IY5",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.4 跳字模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5IqdxHGn-WdU",
        "colab_type": "text"
      },
      "source": [
        "#### 1 嵌入层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Z6CDEi_-Zi0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 370
        },
        "outputId": "1a0f822c-afa6-4d01-c37c-fc28264cf3a3"
      },
      "source": [
        "embed = nn.Embedding(num_embeddings=20, embedding_dim=4)\n",
        "embed.weight"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Parameter containing:\n",
              "tensor([[ 0.5857, -0.8756,  0.6874, -1.0600],\n",
              "        [ 2.3237,  1.0624,  1.2328,  0.6584],\n",
              "        [ 0.7717,  1.0390,  1.4787, -0.3867],\n",
              "        [-1.0761,  0.2477,  0.4813,  0.3817],\n",
              "        [-1.1075,  0.2145,  0.5586,  1.1737],\n",
              "        [ 0.1053, -0.2418, -0.2784,  0.1152],\n",
              "        [ 0.6647,  0.2206,  1.9034, -0.0731],\n",
              "        [-0.6615,  0.4606, -1.0465,  1.5683],\n",
              "        [ 1.1895,  1.4421, -0.8648,  0.6017],\n",
              "        [-0.6986, -1.2590,  0.7220,  0.0954],\n",
              "        [ 0.8325,  1.0354,  0.1956, -0.0661],\n",
              "        [-0.7696,  0.2404, -0.7201, -0.0711],\n",
              "        [-0.1542,  0.1029,  0.1461, -0.5842],\n",
              "        [-0.2405,  0.3505,  0.0712,  1.3604],\n",
              "        [ 0.1141,  0.8417, -0.8556, -0.0872],\n",
              "        [-0.8149, -0.5558,  0.9528,  0.7065],\n",
              "        [ 0.3150,  0.1334, -0.9347, -0.4403],\n",
              "        [ 0.3200,  1.0298,  0.5296, -0.1633],\n",
              "        [-1.1286,  0.8492, -1.5572, -0.6401],\n",
              "        [ 1.3195, -1.2292,  0.6236, -0.1975]], requires_grad=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7bym0gY9-lJE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 134
        },
        "outputId": "f2bc550a-97d5-4791-fb2d-75c41fb59ea8"
      },
      "source": [
        "x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)\n",
        "embed(x)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[ 2.3237,  1.0624,  1.2328,  0.6584],\n",
              "         [ 0.7717,  1.0390,  1.4787, -0.3867],\n",
              "         [-1.0761,  0.2477,  0.4813,  0.3817]],\n",
              "\n",
              "        [[-1.1075,  0.2145,  0.5586,  1.1737],\n",
              "         [ 0.1053, -0.2418, -0.2784,  0.1152],\n",
              "         [ 0.6647,  0.2206,  1.9034, -0.0731]]], grad_fn=<EmbeddingBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K9Bjx7iK-yE4",
        "colab_type": "text"
      },
      "source": [
        "#### 2 小批量乘法"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XWRMxjT--3Y3",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "3b75faff-6ddf-468c-e1ec-a92ccce49499"
      },
      "source": [
        "X = torch.ones((2, 1, 4))\n",
        "Y = torch.ones((2, 4, 6))\n",
        "torch.bmm(X, Y).shape "
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([2, 1, 6])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NiwEtrlV_BT8",
        "colab_type": "text"
      },
      "source": [
        "#### 3 跳字模型前向计算"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-AQ0u_c6_Isp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n",
        "    \"\"\"Score each center word against its context and noise words.\n",
        "\n",
        "    Looks up `center` in `embed_v` and `contexts_and_negatives` in `embed_u`,\n",
        "    then returns the batched matrix product of the center vectors with the\n",
        "    transposed context/noise vectors.\n",
        "\n",
        "    NOTE(review): assumes both index tensors are 2-D with the batch on dim 0,\n",
        "    so the embedded tensors are 3-D as torch.bmm requires -- confirm for\n",
        "    callers other than the training loop.\n",
        "    \"\"\"\n",
        "    center_vecs = embed_v(center)\n",
        "    other_vecs = embed_u(contexts_and_negatives)\n",
        "    # Swap the last two axes so bmm contracts over the embedding dimension.\n",
        "    other_vecs_t = other_vecs.permute(0, 2, 1)\n",
        "    return torch.bmm(center_vecs, other_vecs_t)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nm6_GBH5_fqL",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.5 训练模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_o2IEvgX_o2M",
        "colab_type": "text"
      },
      "source": [
        "#### 1 二元交叉熵损失函数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LTIeMjZh_vHA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class SigmoidBinaryCrossEntropyLoss(nn.Module):\n",
        "    \"\"\"Binary cross-entropy on logits with an optional element mask.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super(SigmoidBinaryCrossEntropyLoss, self).__init__()\n",
        "\n",
        "    def forward(self, inputs, targets, mask=None):\n",
        "        \"\"\"Return the element-wise loss averaged over dim 1.\n",
        "\n",
        "        Args:\n",
        "            inputs: logits; needs at least 2 dimensions because of mean(dim=1).\n",
        "            targets: labels, passed to the loss together with `inputs`.\n",
        "            mask: optional element weights passed as `weight`; a 0 entry\n",
        "                contributes no loss. When None, no weighting is applied.\n",
        "\n",
        "        Returns:\n",
        "            The weighted loss with dim 1 reduced by mean. Masked positions\n",
        "            still count in the denominator of that mean.\n",
        "        \"\"\"\n",
        "        inputs, targets = inputs.float(), targets.float()\n",
        "        # `mask` defaults to None, so only convert it when one was passed;\n",
        "        # binary_cross_entropy_with_logits accepts weight=None.\n",
        "        weight = mask.float() if mask is not None else None\n",
        "        res = nn.functional.binary_cross_entropy_with_logits(\n",
        "            inputs, targets, reduction='none', weight=weight)\n",
        "        return res.mean(dim=1)\n",
        "\n",
        "loss = SigmoidBinaryCrossEntropyLoss()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v-TE3f_aAuvz",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "d765101f-5921-434f-9657-f5c3c4d42060"
      },
      "source": [
        "pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])\n",
        "label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])\n",
        "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])\n",
        "loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([0.8740, 1.2100])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vk1PnHwwBRmv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "39641936-3ca6-403c-989a-7afe99fcbc13"
      },
      "source": [
        "def sigmd(x):\n",
        "    \"\"\"Return -log(sigmoid(x)), i.e. log(1 + exp(-x)).\n",
        "\n",
        "    Evaluated as max(-x, 0) + log1p(exp(-|x|)) so the argument of exp is\n",
        "    never positive; the direct form -log(1 / (1 + exp(-x))) raises\n",
        "    OverflowError in math.exp for large negative x.\n",
        "    \"\"\"\n",
        "    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))\n",
        "\n",
        "print('%.7f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4))\n",
        "print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.8739896\n",
            "1.2100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kRSVQKJgCHCi",
        "colab_type": "text"
      },
      "source": [
        "#### 2 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BC0GFNtdCiyx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "embed_size = 100 \n",
        "net = nn.Sequential(\n",
        "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size), \n",
        "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size), \n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eGkixE6CC7Wh",
        "colab_type": "text"
      },
      "source": [
        "#### 3 定义训练函数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2h1HEzYqC_6R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train(net, lr, num_epochs):\n",
        "    \"\"\"Train `net` with Adam and print the mean batch loss and time per epoch.\n",
        "\n",
        "    Batches are read from the module-level `data_iter` and scored with the\n",
        "    module-level `loss`. `net[0]` embeds the center words and `net[1]` the\n",
        "    context/noise words passed to `skip_gram`.\n",
        "    \"\"\"\n",
        "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "    print('training on', device)\n",
        "    net = net.to(device)\n",
        "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "    center_embed, context_embed = net[0], net[1]\n",
        "    for epoch in range(num_epochs):\n",
        "        start = time.time()\n",
        "        l_sum, n = 0.0, 0\n",
        "        for batch in data_iter:\n",
        "            center, context_negative, mask, label = (t.to(device) for t in batch)\n",
        "            pred = skip_gram(center, context_negative, center_embed, context_embed)\n",
        "            per_example = loss(pred.view(label.shape), label, mask)\n",
        "            # Rescale each row's mean so it averages over unmasked entries only.\n",
        "            n_valid = mask.float().sum(dim=1)\n",
        "            l = (per_example * mask.shape[1] / n_valid).mean()\n",
        "            optimizer.zero_grad()\n",
        "            l.backward()\n",
        "            optimizer.step()\n",
        "            l_sum += l.cpu().item()\n",
        "            n += 1\n",
        "        summary = (epoch + 1, l_sum / n, time.time() - start)\n",
        "        print('epoch %d, loss %.2f, time %.2fs' % summary)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q3dvu0uqFIU-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 202
        },
        "outputId": "b9dc891c-feb1-48a8-9de6-5b16a95130da"
      },
      "source": [
        "train(net, 0.01, 10)"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on cuda\n",
            "epoch 1, loss 1.97, time 6.20s\n",
            "epoch 2, loss 0.62, time 6.16s\n",
            "epoch 3, loss 0.45, time 6.20s\n",
            "epoch 4, loss 0.40, time 6.09s\n",
            "epoch 5, loss 0.37, time 6.01s\n",
            "epoch 6, loss 0.35, time 6.11s\n",
            "epoch 7, loss 0.34, time 6.08s\n",
            "epoch 8, loss 0.33, time 6.10s\n",
            "epoch 9, loss 0.32, time 6.09s\n",
            "epoch 10, loss 0.32, time 6.07s\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8UhBn_AaFL88",
        "colab_type": "text"
      },
      "source": [
        "### 10.3.6 应用词嵌入模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dg3_LMfdFmyZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "da604ac6-1184-4269-beaf-9dd066d4e0e6"
      },
      "source": [
        "def get_similar_tokens(query_token, k, embed):\n",
        "    \"\"\"Print the k tokens whose embeddings are most cosine-similar to `query_token`.\n",
        "\n",
        "    k + 1 results are requested and the first one is skipped (presumably the\n",
        "    query token itself). Reads the module-level `token_to_idx` and\n",
        "    `idx_to_token`.\n",
        "    \"\"\"\n",
        "    weights = embed.weight.data\n",
        "    query_vec = weights[token_to_idx[query_token]]\n",
        "    # 1e-9 keeps the denominator non-zero for all-zero embedding rows.\n",
        "    sq_norm_product = torch.sum(weights * weights, dim=1) * torch.sum(query_vec * query_vec) + 1e-9\n",
        "    cos = torch.matmul(weights, query_vec) / sq_norm_product.sqrt()\n",
        "    top_indices = torch.topk(cos, k=k+1)[1].cpu().numpy()\n",
        "    for token_idx in top_indices[1:]:\n",
        "        print('cosine sim=%.3f: %s' % (cos[token_idx], (idx_to_token[token_idx])))\n",
        "\n",
        "get_similar_tokens('chip', 3, net[0])"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "cosine sim=0.482: software\n",
            "cosine sim=0.475: armonk\n",
            "cosine sim=0.462: lavelle\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}